%  FILE:   cnn_train_dag_hardmine.m
%
%    This function trains the detection network with online hard negative
%    mining. It is a modified version of MatConvNet's cnn_train_dag that,
%    after each forward pass, rewrites the classification labels so that
%    only a balanced sample of hard positives and hard negatives
%    contributes to the gradient in the backward pass.
%
%  INPUT:  net      (a dagnn.DagNN detection network)
%          imdb     (image database structure)
%          getBatch (function handle that returns the inputs of a batch)
%
%  OUTPUT: net      (the trained network)
%          stats    (per-epoch training/validation statistics)

function [net,stats] = cnn_train_dag_hardmine(net, imdb, getBatch, varargin)
%CNN_TRAIN_DAG_HARDMINE Trains a CNN with hard negative mining via the DagNN wrapper
%    CNN_TRAIN_DAG_HARDMINE() is similar to CNN_TRAIN(), but works with
%    the DagNN wrapper instead of the SimpleNN wrapper, and additionally
%    performs online hard example selection on the classification labels.

% Copyright (C) 2014-16 Andrea Vedaldi.
% All rights reserved.
%
% This file is part of the VLFeat library and is made available under
% the terms of the BSD license (see the COPYING file).

opts.expDir = fullfile('data','exp') ;
opts.continue = true ;
opts.batchSize = 256 ;
opts.numSubBatches = 1 ;
opts.train = [] ;
opts.val = [] ;
opts.gpus = [] ;
opts.prefetch = false ;
opts.numEpochs = 300 ;
opts.learningRate = 0.001 ;
opts.weightDecay = 0.0005 ;
opts.momentum = 0.9 ;
opts.randomSeed = 0 ;
opts.memoryMapFile = fullfile(tempdir, 'matconvnet.bin') ;
opts.profile = false ;
opts.snapshotIter = inf;

opts.sampleSize = 256;
opts.posFraction = 0.5;
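% Hard mining budget: per image, at most sampleSize*posFraction positives
% (128 with the defaults above) and sampleSize*(1-posFraction) negatives
% (also 128) are kept in the loss; see the selection step in process_epoch.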

opts.keepDilatedZeros = false;
opts.derOutputs = {'objective', 1} ;
opts.extractStatsFn = @extractStats ;
opts.plotStatistics = false;
opts = vl_argparse(opts, varargin) ;


if ~exist(opts.expDir, 'dir'), mkdir(opts.expDir) ; end
if isempty(opts.train), opts.train = find(imdb.images.set==1) ; end
if isempty(opts.val), opts.val = find(imdb.images.set==2) ; end
if isnan(opts.train), opts.train = [] ; end
if isnan(opts.val), opts.val = []; end

% -------------------------------------------------------------------------
%                                                            Initialization
% -------------------------------------------------------------------------

evaluateMode = isempty(opts.train) ;
if ~evaluateMode
  if isempty(opts.derOutputs)
    error('DEROUTPUTS must be specified when training.\n') ;
  end
end

state.getBatch = getBatch ;
stats = [] ;

% -------------------------------------------------------------------------
%                                                        Train and validate
% -------------------------------------------------------------------------

modelPath = @(ep) fullfile(opts.expDir, sprintf('net-epoch-%d.mat',ep));
modelFigPath = fullfile(opts.expDir, 'net-train.pdf') ;


start = opts.continue * findLastCheckpoint(opts.expDir) ;
if start >= 1
  fprintf('%s: resuming by loading epoch %d\n', mfilename, start) ;
  [net, stats] = loadState(modelPath(start)) ;
end

% check if loss layer is DetLoss
if ~isa(net.layers(net.getLayerIndex('loss_cls')).block, 'dagnn.DetLoss')
    net.removeLayer('loss_cls');
    net.addLayer('loss_cls', dagnn.DetLoss('loss', 'logistic'), ...
                 {'score_cls', 'label_cls'}, 'loss_cls');
    disp('Start using dagnn.DetLoss for loss');
end
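
% NOTE dagnn.DetLoss exposes a per-location loss map (block.loss_map),
% which the hard example selection in process_epoch reads to rank
% candidate locations by difficulty.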

% check if we have an extra variable for spatial loss map

for epoch=start+1:opts.numEpochs
  % Set the random seed based on the epoch and opts.randomSeed.
  % This is important for reproducibility, including when training
  % is restarted from a checkpoint.

  rng(epoch + opts.randomSeed) ;
  prepareGPUs(opts, epoch == start+1) ;

  % Train for one epoch.
  state.epoch = epoch ;

  state.learningRate = opts.learningRate(min(epoch, numel(opts.learningRate))) ;
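  % opts.learningRate may be a scalar or a per-epoch vector; once the
  % epoch index exceeds the vector length, the last entry is reused.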
  state.train = opts.train(randperm(numel(opts.train))) ; % shuffle
  state.val = opts.val(randperm(numel(opts.val))) ;
  state.imdb = imdb ;

  if numel(opts.gpus) <= 1
    [stats.train(epoch),prof] = process_epoch(net, state, opts, 'train') ;
    stats.val(epoch) = process_epoch(net, state, opts, 'val') ;
    if opts.profile
      profview(0,prof) ;
      keyboard ;
    end
  else
    savedNet = net.saveobj() ;
    spmd
      net_ = dagnn.DagNN.loadobj(savedNet) ;
      [stats_.train, prof_] = process_epoch(net_, state, opts, 'train') ;
      stats_.val = process_epoch(net_, state, opts, 'val') ;
      if labindex == 1, savedNet_ = net_.saveobj() ; end
    end
    net = dagnn.DagNN.loadobj(savedNet_{1}) ;
    stats__ = accumulateStats(stats_) ;
    stats.train(epoch) = stats__.train ;
    stats.val(epoch) = stats__.val ;
    if opts.profile
      mpiprofile('viewer', [prof_{:,1}]) ;
      keyboard ;
    end
    clear net_ stats_ stats__ savedNet savedNet_ ;
  end

  % save the checkpoint for this epoch
  if ~evaluateMode
    saveState(modelPath(epoch), net, stats) ;
  end

  if opts.plotStatistics
    switchFigure(1) ; clf ;
    plots = setdiff(...
      cat(2,...
      fieldnames(stats.train)', ...
      fieldnames(stats.val)'), {'num', 'time'}) ;
    for p = plots
      p = char(p) ;
      values = zeros(0, epoch) ;
      leg = {} ;
      for f = {'train', 'val'}
        f = char(f) ;
        if isfield(stats.(f), p)
          tmp = [stats.(f).(p)] ;
          values(end+1,:) = tmp(1,:)' ;
          leg{end+1} = f ;
        end
      end
      subplot(1,numel(plots),find(strcmp(p,plots))) ;
      plot(1:epoch, values','o-') ;
      xlabel('epoch') ;
      title(p) ;
      legend(leg{:}) ;
      grid on ;
    end
    drawnow ;
    print(1, modelFigPath, '-dpdf') ;
  end
end

% -------------------------------------------------------------------------
function [stats, prof] = process_epoch(net, state, opts, mode)
% -------------------------------------------------------------------------

% initialize empty momentum
if strcmp(mode,'train')
  state.momentum = num2cell(zeros(1, numel(net.params))) ;
end

% move CNN to GPU as needed
numGpus = numel(opts.gpus) ;
if numGpus >= 1
  net.move('gpu') ;
  if strcmp(mode,'train')
    state.momentum = cellfun(@gpuArray,state.momentum,'UniformOutput',false) ;
  end
end
if numGpus > 1
  mmap = map_gradients(opts.memoryMapFile, net, numGpus) ;
else
  mmap = [] ;
end

% profile
if opts.profile
  if numGpus <= 1
    profile clear ;
    profile on ;
  else
    mpiprofile reset ;
    mpiprofile on ;
  end
end

subset = state.(mode) ;
num = 0 ;
stats.num = 0 ; % return something even if subset = []
stats.time = 0 ;
adjustTime = 0 ;

start = tic ;

% NOTE iterate in fixed batchSize steps so the effective batch size stays
% consistent across iterations
for t = 1:opts.batchSize:numel(subset) 
  fprintf('%s: epoch %02d: %3d/%3d:', mode, state.epoch, ...
          fix((t-1)/opts.batchSize)+1, ceil(numel(subset)/opts.batchSize)) ;
  batchSize = min(opts.batchSize, numel(subset) - t + 1) ;

  for s=1:opts.numSubBatches
    % get this image batch and prefetch the next
    batchStart = t + (labindex-1) + (s-1) * numlabs ;
    batchEnd = min(t+opts.batchSize-1, numel(subset)) ;
    batch = subset(batchStart : opts.numSubBatches * numlabs : batchEnd) ;
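    % each lab (GPU) takes every (numSubBatches*numlabs)-th image starting
    % at its own offset, so the sub-batches across labs partition the batch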
    num = num + numel(batch) ;
    if numel(batch) == 0, continue ; end

    inputs = state.getBatch(state.imdb, batch) ;

    if opts.prefetch
      if s == opts.numSubBatches
        batchStart = t + (labindex-1) + opts.batchSize ;
        batchEnd = min(t+2*opts.batchSize-1, numel(subset)) ;
      else
        batchStart = batchStart + numlabs ;
      end
      nextBatch = subset(batchStart : opts.numSubBatches * numlabs : batchEnd) ;
      state.getBatch(state.imdb, nextBatch) ;
    end

    if strcmp(mode, 'train') 
        net.mode = 'normal' ;
        net.accumulateParamDers = (s ~= 1) ;

        % forward pass
        net.forward(inputs, opts.derOutputs);

        % NOTE hard example selection: rewrite the classification labels
        % between the forward and backward passes. The regression labels
        % need no change because the regression loss only fires at
        % positive classification labels.
        loss_cls_map = net.layers(net.getLayerIndex('loss_cls')).block.loss_map;
        label_cls = net.vars(net.getVarIndex('label_cls')).value;

        % simple heuristic to keep the selected examples hard and diverse
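        % step 1: drop easy examples. A per-location loss below 0.03 (an
        % empirical constant) zeroes the label, and label 0 is ignored by
        % the loss. Step 2 below randomly subsamples the survivors to
        % respect the per-image budget.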
        label_cls(loss_cls_map<0.03) = 0;
        pos_num = 0; neg_num = 0;
        for i = 1:size(label_cls,4)
            clsmap = label_cls(:,:,:,i);
            
            pos_maxnum = opts.sampleSize*opts.posFraction;
            pos_idx = find(clsmap(:)==1);
            pos_num = pos_num + numel(pos_idx);
            if numel(pos_idx) > pos_maxnum
                didx = Shuffle(numel(pos_idx), 'index', numel(pos_idx)-pos_maxnum);
                clsmap(pos_idx(didx)) = 0;
            end
            
            neg_maxnum = pos_maxnum*(1-opts.posFraction)/opts.posFraction;
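            % negative budget: pos_maxnum*(1-posFraction)/posFraction,
            % i.e. sampleSize*(1-posFraction) = 128 with the defaults, so
            % positives and negatives together stay within opts.sampleSize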
            neg_idx = find(clsmap(:)==-1);
            neg_num = neg_num + numel(neg_idx);
            if numel(neg_idx) > neg_maxnum
                ridx = Shuffle(numel(neg_idx), 'index', neg_maxnum);
                didx = 1:numel(neg_idx);
                didx(ridx) = [];
                clsmap(neg_idx(didx)) = 0;
            end
            
            label_cls(:,:,:,i) = clsmap;
        end
        fprintf(' (+ %d) (- %d) ', pos_num, neg_num);

        net.vars(net.getVarIndex('label_cls')).value = label_cls;

        % backward pass
        net.backward(inputs, opts.derOutputs);
    else
        error('do not use this function for testing');
    end
  end

  % accumulate gradient
  if strcmp(mode, 'train')
    if ~isempty(mmap)
      write_gradients(mmap, net) ;
      labBarrier() ;
    end
    state = accumulate_gradients(state, net, opts, batchSize, mmap) ;
  end

  % get statistics
  time = toc(start) + adjustTime ;
  batchTime = time - stats.time ;
  stats = opts.extractStatsFn(net) ;
  stats.num = num ;
  stats.time = time ;
  currentSpeed = batchSize / batchTime ;
  averageSpeed = (t + batchSize - 1) / time ;
  if t == opts.batchSize + 1
    % compensate for the first iteration, which is an outlier
    adjustTime = 2*batchTime - time ;
    stats.time = time + adjustTime ;
  end

  fprintf(' %.1f (%.1f) Hz', averageSpeed, currentSpeed) ;
  for f = setdiff(fieldnames(stats)', {'num', 'time'})
    f = char(f) ;
    fprintf(' %s:', f) ;
    fprintf(' %.6f', stats.(f)) ;
  end
  fprintf('\n') ;
  
end

if ~isempty(mmap)
  unmap_gradients(mmap) ;
end

if opts.profile
  if numGpus <= 1
    prof = profile('info') ;
    profile off ;
  else
    prof = mpiprofile('info');
    mpiprofile off ;
  end
else
  prof = [] ;
end

net.reset() ;
net.move('cpu') ;

% -------------------------------------------------------------------------
function state = accumulate_gradients(state, net, opts, batchSize, mmap)
% -------------------------------------------------------------------------
numGpus = numel(opts.gpus) ;
otherGpus = setdiff(1:numGpus, labindex) ;

for p=1:numel(net.params)

  % accumulate gradients from multiple labs (GPUs) if needed
  if numGpus > 1
    tag = net.params(p).name ;
    for g = otherGpus
      tmp = gpuArray(mmap.Data(g).(tag)) ;
      net.params(p).der = net.params(p).der + tmp ;
    end
  end

  % zero out the gradient at the zero-filled taps of dilated filters
  if opts.keepDilatedZeros
      % only the dilated convolutions in ResNet-50 have 5x5 filters
      if size(net.params(p).der,1)==5 || size(net.params(p).der,2)==5
          net.params(p).der(2:2:4,:,:,:) = 0;
          net.params(p).der(:,2:2:4,:,:) = 0;
      end
  end

  switch net.params(p).trainMethod

    case 'average' % mainly for batch normalization
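      % exponential moving average of the collected moments:
      %   value <- (1-lr)*value + (lr/(batchSize*fanout))*der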
      thisLR = net.params(p).learningRate ;
      net.params(p).value = ...
          (1 - thisLR) * net.params(p).value + ...
          (thisLR/batchSize/net.params(p).fanout) * net.params(p).der ;

    case 'gradient'
      thisDecay = opts.weightDecay * net.params(p).weightDecay ;
      thisLR = state.learningRate * net.params(p).learningRate ;
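      % SGD with momentum and weight decay:
      %   m <- momentum*m - decay*w - (1/batchSize)*g ;   w <- w + lr*m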
      state.momentum{p} = opts.momentum * state.momentum{p} ...
        - thisDecay * net.params(p).value ...
        - (1 / batchSize) * net.params(p).der ;
      net.params(p).value = net.params(p).value + thisLR * state.momentum{p} ;

    otherwise
      error('Unknown training method ''%s'' for parameter ''%s''.', ...
        net.params(p).trainMethod, ...
        net.params(p).name) ;
  end
end

% -------------------------------------------------------------------------
function mmap = map_gradients(fname, net, numGpus)
% -------------------------------------------------------------------------
format = {} ;
for i=1:numel(net.params)
  format(end+1,1:3) = {'single', size(net.params(i).value), net.params(i).name} ;
end
format(end+1,1:3) = {'double', [3 1], 'errors'} ;
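% one record of this format per GPU ('Repeat' below): lab g writes its
% gradients into mmap.Data(g), and accumulate_gradients sums over the
% other labs' slots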
if ~exist(fname, 'file') && (labindex == 1)
  f = fopen(fname,'wb') ;
  for g=1:numGpus
    for i=1:size(format,1)
      fwrite(f,zeros(format{i,2},format{i,1}),format{i,1}) ;
    end
  end
  fclose(f) ;
end
labBarrier() ;
mmap = memmapfile(fname, ...
                  'Format', format, ...
                  'Repeat', numGpus, ...
                  'Writable', true) ;

% -------------------------------------------------------------------------
function write_gradients(mmap, net)
% -------------------------------------------------------------------------
for i=1:numel(net.params)
  mmap.Data(labindex).(net.params(i).name) = gather(net.params(i).der) ;
end

% -------------------------------------------------------------------------
function unmap_gradients(mmap)
% -------------------------------------------------------------------------

% -------------------------------------------------------------------------
function stats = accumulateStats(stats_)
% -------------------------------------------------------------------------
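% Average the per-lab statistics, weighting each lab's contribution by the
% number of images it processed (stats_{g}.(s).num).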

for s = {'train', 'val'}
  s = char(s) ;
  total = 0 ;

  % initialize stats structure with the same fields and same order as
  % stats_{1}
  stats__ = stats_{1} ;
  names = fieldnames(stats__.(s))' ;
  values = zeros(1, numel(names)) ;
  fields = cat(1, names, num2cell(values)) ;
  stats.(s) = struct(fields{:}) ;

  for g = 1:numel(stats_)
    stats__ = stats_{g} ;
    num__ = stats__.(s).num ;
    total = total + num__ ;

    for f = setdiff(fieldnames(stats__.(s))', 'num')
      f = char(f) ;
      stats.(s).(f) = stats.(s).(f) + stats__.(s).(f) * num__ ;

      if g == numel(stats_)
        stats.(s).(f) = stats.(s).(f) / total ;
      end
    end
  end
  stats.(s).num = total ;
end

% -------------------------------------------------------------------------
function stats = extractStats(net)
% -------------------------------------------------------------------------
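% Collect the running average of every loss layer (dagnn.HuberLoss or
% dagnn.DetLoss), keyed by the layer's first output variable name.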
sel = find(cellfun(@(x)(isa(x,'dagnn.HuberLoss')||isa(x,'dagnn.DetLoss')), ...
                   {net.layers.block})) ;
stats = struct() ;
for i = 1:numel(sel)
  stats.(net.layers(sel(i)).outputs{1}) = net.layers(sel(i)).block.average ;
end

% -------------------------------------------------------------------------
function saveState(fileName, net, stats)
% -------------------------------------------------------------------------
net_ = net ;
net = net_.saveobj() ;
save(fileName, 'net', 'stats') ;

% -------------------------------------------------------------------------
function [net, stats] = loadState(fileName)
% -------------------------------------------------------------------------
load(fileName, 'net', 'stats') ;
net = dagnn.DagNN.loadobj(net) ;

% -------------------------------------------------------------------------
function epoch = findLastCheckpoint(modelDir)
% -------------------------------------------------------------------------
list = dir(fullfile(modelDir, 'net-epoch-*.mat')) ;
tokens = regexp({list.name}, 'net-epoch-([\d]+)\.mat', 'tokens') ;
epoch = cellfun(@(x) sscanf(x{1}{1}, '%d'), tokens) ;
epoch = max([epoch 0]) ;


% -------------------------------------------------------------------------
function switchFigure(n)
% -------------------------------------------------------------------------
if get(0,'CurrentFigure') ~= n
  try
    set(0,'CurrentFigure',n) ;
  catch
    figure(n) ;
  end
end

% -------------------------------------------------------------------------
function prepareGPUs(opts, cold)
% -------------------------------------------------------------------------
numGpus = numel(opts.gpus) ;
if numGpus > 1
  % check parallel pool integrity as it could have timed out
  pool = gcp('nocreate') ;
  if ~isempty(pool) && pool.NumWorkers ~= numGpus
    delete(pool) ;
  end
  pool = gcp('nocreate') ;
  if isempty(pool)
    parpool('local', numGpus) ;
    cold = true ;
  end
  if exist(opts.memoryMapFile, 'file')
    delete(opts.memoryMapFile) ;
  end
end
if numGpus >= 1 && cold
  fprintf('%s: resetting GPU\n', mfilename) ;
  if numGpus == 1
    gpuDevice(opts.gpus);
  else
    spmd, gpuDevice(opts.gpus(labindex)), end
  end
end
